import re
import sys
import textwrap
import os
import unittest
from dataclasses import dataclass
from functools import cache
from test import support
from test.support.script_helper import run_python_until_end

# Absolute path of the strace executable; strace_python() uses it as the
# command that wraps the Python interpreter.
_strace_binary = "/usr/bin/strace"
# Matches one strace event line of the form ``name(args) = result``.
# "args" cannot contain ")", so a line whose arguments include one (e.g. a
# written string containing ")") does not match and is ignored by the parser.
# "returncode" captures everything after "=", including any error name and
# description strace appends.
_syscall_regex = re.compile(
    r"(?P<syscall>[^(]*)\((?P<args>[^)]*)\)\s*[=]\s*(?P<returncode>.+)")
# Matches strace's final status line for a normally exiting process, e.g.
# b"+++ exited with 0 +++"; "returncode" captures the decimal exit status.
_returncode_regex = re.compile(
    br"\+\+\+ exited with (?P<returncode>\d+) \+\+\+")


@dataclass
class StraceEvent:
    """A single parsed strace line: ``syscall(args...) = returncode``.

    Every part is the raw text strace printed; nothing is converted into a
    Python value (string arguments, for example, are still quoted).
    """

    # Name of the system call, e.g. "write".
    syscall: str
    # The call's arguments as split on "," by StraceResult.events(), each
    # with surrounding whitespace stripped.
    args: list[str]
    # Everything strace printed after the "=" sign.
    returncode: str


@dataclass
class StraceResult:
    """Outcome of running a Python snippet under strace.

    Instances are built by strace_python(). When that helper hits an error,
    both return codes are -1 and event_bytes holds a single synthetic
    ``error(...) = -1`` line.
    """

    # Exit status of the strace process.
    strace_returncode: int
    # Exit status of the traced interpreter, taken from strace's final
    # "+++ exited with N +++" line.
    python_returncode: int
    # The event messages strace produced: essentially strace's stderr with the
    # trailing exit-status line removed.
    event_bytes: bytes
    stdout: bytes
    stderr: bytes

    def events(self):
        """Parse event_bytes into a list of StraceEvent objects.

        Lines that _syscall_regex does not match (including calls whose
        arguments contain ")") are skipped. Arguments are split naively on
        commas, so an argument that itself contains a comma ends up in
        several pieces. The traced program is assumed not to mix non-UTF-8
        text into strace's output; undecodable bytes are preserved with the
        surrogateescape handler.
        """
        text = self.event_bytes.decode('utf-8', 'surrogateescape')
        parsed = []
        for line in text.splitlines():
            m = _syscall_regex.match(line)
            if m is None:
                continue
            arguments = [piece.strip() for piece in m["args"].split(",")]
            parsed.append(StraceEvent(m["syscall"], arguments, m["returncode"]))
        return parsed

    def sections(self):
        """Group events by the "MARK <name>" lines the traced program wrote.

        Each write() whose second argument starts with "MARK " begins (or
        resumes, if the name was seen before) the section called <name>; the
        marker write itself is not recorded. Events seen before any marker
        are collected under "__startup". This lets a test inspect just the
        syscalls of the snippet it cares about, without interpreter startup
        or setup noise.
        """
        name = "__startup"
        grouped = {name: []}
        for ev in self.events():
            is_marker = (ev.syscall == 'write' and len(ev.args) > 2
                         and ev.args[1].startswith("\"MARK "))
            if not is_marker:
                grouped[name].append(ev)
                continue
            # args[1] is the quoted literal, e.g. "MARK code\n": keep the text
            # after "MARK " and drop the trailing \n" characters.
            name = ev.args[1].split(" ", 1)[1].removesuffix('\\n"')
            grouped.setdefault(name, [])
        return grouped

def _filter_memory_call(call):
    """Return True if *call* looks like a memory-management syscall.

    munmap and mprotect always count. mmap counts only when its flags (the
    fourth argument) mention "MAP_ANON"; mmaps without it (e.g. file-backed
    ones) are not treated as memory calls.
    """
    name = call.syscall
    if name in ("munmap", "mprotect"):
        return True
    # Substring test on purpose: it matches both "MAP_ANONYMOUS" and its
    # "MAP_ANON" alias.
    return name == "mmap" and "MAP_ANON" in call.args[3]


def filter_memory(syscalls):
    """Drop memory-allocation calls so only the remaining calls (e.g. file I/O) stay.

    Some syscalls (mmap, munmap, ...) may act on a file or merely hand out a
    block of memory. This returns a new list holding, in their original
    order, the events of *syscalls* that _filter_memory_call() rejects.
    """
    return [event for event in syscalls if not _filter_memory_call(event)]


@support.requires_subprocess()
def strace_python(code, strace_flags, check=True):
    """Run *code* via ``python -c`` under strace and return a StraceResult.

    *code* is dedented first, and *strace_flags* is appended to the strace
    command line. When *check* is true, a non-zero exit status of strace or
    of the traced interpreter is reported through the run result's
    ``fail()`` method.

    If running strace raises OSError, or strace's stderr cannot be split
    into events plus a final "+++ exited with N +++" line, no exception is
    raised. Instead, the returned result has strace_returncode and
    python_returncode set to -1, and event_bytes is a single synthetic
    ``error(...) = -1`` line describing the problem.
    """
    proc = None

    def error_result(reason, details):
        # Encode the failure as an "error(...) = -1" line so it still parses
        # as an event; keep the captured output when the run got that far.
        message = f"error({reason},details={details!r}) = -1"
        return StraceResult(
            strace_returncode=-1,
            python_returncode=-1,
            event_bytes=message.encode('utf-8'),
            stdout=proc.out if proc else b"",
            stderr=proc.err if proc else b"",
        )

    try:
        proc, cmd_line = run_python_until_end(
            "-c", textwrap.dedent(code),
            __run_using_command=[_strace_binary] + strace_flags,
        )
    except OSError as exc:
        return error_result("Caught OSError", exc)

    if check and proc.rc:
        proc.fail(cmd_line)

    # The last line of strace's stderr carries the interpreter's exit status;
    # everything before it is the event log.
    trimmed = proc.err.strip()
    event_log, newline, status_line = trimmed.rpartition(b"\n")
    if not newline:
        return error_result("Expected strace events and exit code line",
                            trimmed[-50:])

    status = _returncode_regex.match(status_line)
    if status is None:
        return error_result("Expected to find returncode in last line.",
                            status_line[:50])

    python_rc = int(status["returncode"])
    if check and python_rc:
        proc.fail(cmd_line)

    return StraceResult(
        strace_returncode=proc.rc,
        python_returncode=python_rc,
        event_bytes=event_log,
        stdout=proc.out,
        stderr=proc.err,
    )


def get_events(code, strace_flags, prelude, cleanup):
    """Trace *code* under strace and return only the events it produced.

    *prelude*, *code* and *cleanup* are dedented and run in that order in a
    single interpreter through strace_python() (with its default
    check=True). A "MARK <name>" line is printed before each part. The
    returned list holds the StraceEvents recorded between the "MARK code"
    and "MARK cleanup" writes.

    Raises KeyError with a diagnostic message when the trace has no "code"
    section. This typically means strace_python() could not parse strace's
    output, or *strace_flags* does not trace write() calls, which the MARK
    lines rely on.
    """
    # NOTE: The flush is currently required to prevent the prints from getting
    # buffered and done all at once at exit
    prelude = textwrap.dedent(prelude)
    code = textwrap.dedent(code)
    cleanup = textwrap.dedent(cleanup)
    to_run = f"""
print("MARK prelude", flush=True)
{prelude}
print("MARK code", flush=True)
{code}
print("MARK cleanup", flush=True)
{cleanup}
print("MARK __shutdown", flush=True)
    """
    trace = strace_python(to_run, strace_flags)
    all_sections = trace.sections()
    if 'code' not in all_sections:
        # sections() only opens a new group on a traced write() of a "MARK"
        # line. A missing "code" group therefore means either the output was
        # unparseable (strace_python() then reports -1 return codes and a
        # synthetic "error(...)" event) or write() was not traced. Report
        # that instead of a bare KeyError: 'code'.
        raise KeyError(
            "no 'code' section in strace output "
            f"(strace_returncode={trace.strace_returncode}, "
            f"python_returncode={trace.python_returncode}, "
            f"sections={list(all_sections)!r}, "
            f"first events={trace.events()[:3]!r}); "
            "strace_flags must trace write() calls")
    return all_sections['code']


def get_syscalls(code, strace_flags, prelude="", cleanup="",
                 ignore_memory=True):
    """Return, in order, the syscall names that running *code* produces.

    Only events from the code section returned by get_events() are used.
    When *ignore_memory* is true (the default), memory-management calls are
    removed first via filter_memory().
    """
    traced = get_events(code, strace_flags, prelude=prelude, cleanup=cleanup)
    kept = filter_memory(traced) if ignore_memory else traced
    return [event.syscall for event in kept]


# Probing spawns a subprocess, which is moderately expensive, so the answer is
# cached and shared by every requires_strace() call.
@cache
def _can_strace():
    """Return True if strace can trace a trivial Python run successfully."""
    probe = strace_python(
        "import sys; sys.exit(0)",
        # The --trace option requires strace 5.5 or newer (gh-133741).
        ["--trace=%process"],
        check=False,
    )
    if probe.strace_returncode != 0 or probe.python_returncode != 0:
        return False
    assert probe.events(), "Should have parsed multiple calls"
    return True


def requires_strace():
    """Return a unittest decorator that skips tests needing a working strace.

    The conditions are checked in order: the platform must be Linux,
    LD_PRELOAD must be unset, and ASan/MSan builds are excluded. Only when
    none of these rule strace out is the (cached) _can_strace() probe run.
    """
    if sys.platform != "linux":
        reason = "Linux only, requires strace."
    elif "LD_PRELOAD" in os.environ:
        # Packaging tools such as Debian's `fakeroot` or Gentoo's `sandbox`
        # intercept system calls through LD_PRELOAD. That changes the set of
        # syscalls a program makes, breaking tests that expect an exact set.
        reason = "Not supported when LD_PRELOAD is intercepting system calls."
    elif support.check_sanitizer(address=True, memory=True):
        reason = "LeakSanitizer does not work under ptrace (strace, gdb, etc)"
    else:
        return unittest.skipUnless(_can_strace(), "Requires working strace")
    return unittest.skip(reason)


# Names exported by ``from ... import *``.
__all__ = [
    "filter_memory",
    "get_events",
    "get_syscalls",
    "requires_strace",
    "strace_python",
    "StraceEvent",
    "StraceResult",
]
